from .base_model import BaseMOMAB
from utils import pareto_front, sherman_morrison, pareto_front_v2
import numpy as np

class MOLB_TS(BaseMOMAB):
    """Multi-objective linear bandit using sampled optimistic estimates.

    One parameter estimate of dimension ``d`` is kept per objective; all
    objectives share a single regularised inverse design matrix ``Vinv``.
    At each round, ``num_samples`` parameter vectors are drawn per objective
    from a Gaussian centred on the current estimate with covariance ``Vinv``;
    the largest predicted reward over those draws is used as the optimistic
    value of each arm. Arms are then ranked by a random convex combination
    of their per-objective optimistic values.

    NOTE(review): ``self.m`` and ``self.t`` are read but never assigned in
    this class; they are presumably initialised by ``BaseMOMAB.__init__``
    -- confirm against the base class.
    """

    def __init__(self, d, m, num_samples, delta=0.01, noise=0.1, lam=None, name="MOLB_TS"):
        """Set up the estimator state.

        Args:
            d: Dimension of the context / parameter vectors.
            m: Number of objectives (forwarded to ``BaseMOMAB``).
            num_samples: Number of parameter draws per objective per round.
            delta: Value used inside the logarithm of ``cal_log_1``.
            noise: Multiplier of the square-root log term in ``bound_hat``.
            lam: Ridge regularisation weight; ``None`` selects 1.
            name: Label stored on the instance.
        """
        super().__init__(m=m)

        self.noise = noise
        self.delta = delta
        self.num_samples = num_samples
        # Identity check: only a missing value triggers the default.
        self.lam = 1 if lam is None else lam
        self.d = d

        # Estimator state. `hat_theta` and `V` are not read or updated by any
        # method of this class; they are kept for external readers.
        self.hat_theta = np.zeros(self.d)
        self.V = self.lam * np.eye(d)
        self.Vinv = (1 / self.lam) * np.eye(d)
        self.b_t = np.zeros((self.m, self.d))
        self.theta_hat = np.zeros((self.m, d))

        # Both radii are evaluated once, with the value `self.t` has right
        # after the base-class constructor; nothing here refreshes them.
        self.bound_hat = np.sqrt(self.lam) + self.noise * np.sqrt(self.cal_log_1())
        self.bound_tilde = self.bound_hat * np.sqrt(self.cal_log_2())

        self.name = name
        self.settings = {'lambda': self.lam}

    def cal_log_1(self):
        """Return ``d * log(m * (1 + (t - 1) / (d * lam)) / delta)``."""
        growth = 1 + (self.t - 1) / (self.d * self.lam)
        return self.d * np.log(self.m * growth / self.delta)

    def cal_log_2(self):
        """Return ``2 * d * log(2 * m * num_samples * d * t)``."""
        return 2 * self.d * np.log(2 * self.m * self.num_samples * self.d * self.t)

    def sample_simplex(self):
        """Draw a random weight vector of length ``m`` that sums to one.

        The weights are i.i.d. unit-scale exponential draws divided by
        their total.
        """
        k = np.random.exponential(scale=1.0, size=self.m)
        return k / k.sum()

    def pareto_front(self, Y):
        """Return the row indices of ``Y`` that are not strictly dominated.

        Row ``i`` is dropped when some row ``j`` that is still kept is
        strictly larger than it in every column.

        Args:
            Y: 2-D array with one row per candidate and one column per
                objective.

        Returns:
            List of surviving row indices in increasing order.
        """
        num_rows = Y.shape[0]
        pareto_index = list(range(num_rows))
        for i in range(num_rows):
            # The scan finishes before the list is modified, so the list is
            # never mutated while it is being iterated.
            if any(np.max(Y[i, :] - Y[j, :]) < 0 for j in pareto_index):
                pareto_index.remove(i)
        return pareto_index

    def select_ac(self, contexts):
        """Choose an arm and fold its context into ``Vinv``.

        Args:
            contexts: 2-D array with one row per arm and ``d`` columns.

        Returns:
            Index (row of ``contexts``) of the selected arm.

        Side effects:
            ``self.Vinv`` is replaced by the result of ``sherman_morrison``
            for the chosen context, so a later ``update`` call already sees
            the new matrix.
        """
        optimistic_mean = np.zeros((contexts.shape[0], self.m))
        for i, theta in enumerate(self.theta_hat):
            # Alternative covariance scaling kept from the original code:
            # self.bound_hat ** 2 * self.Vinv
            # One batched draw: the covariance is validated and factorised
            # once instead of once per sample.
            samples = np.random.multivariate_normal(
                theta, self.Vinv, size=self.num_samples
            )
            # Optimistic value of each arm = best prediction over the draws.
            optimistic_mean[:, i] = np.max(contexts @ samples.T, axis=1)

        # original algorithm
        # p_idx, _ = pareto_front(optimistic_mean)
        # p_idx, _ = pareto_front_v2(optimistic_mean, p_idx)
        # idx = np.random.choice(p_idx)

        # faster selection: maximise a random scalarisation of the objectives
        weight_sum = optimistic_mean @ self.sample_simplex()
        idx = np.argmax(weight_sum)

        # presumably the rank-one inverse update for the chosen context --
        # confirm against utils.sherman_morrison
        self.Vinv = sherman_morrison(contexts[idx], self.Vinv)
        return idx

    def update(self, reward, context):
        """Incorporate the observed reward vector for the played context.

        Args:
            reward: One reward per objective (length ``m``).
            context: Context vector of the arm that was played (length ``d``).
        """
        # Row j of b_t accumulates reward[j] * context.
        self.b_t += np.asarray(reward).reshape(-1, 1) * context
        self.theta_hat = (self.Vinv @ self.b_t.T).T
        self.t += 1
    